{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c3753329-d35a-4353-94b5-5c3ce726ffe2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import SimpleITK as sitk\n",
    "import nibabel as nib\n",
    "\n",
    "from monai.transforms import (\n",
    "    CenterSpatialCropd,\n",
    "    Compose,\n",
    "    EnsureChannelFirstd,\n",
    "    LoadImaged,\n",
    "    Resized,\n",
    "    SpatialPadd,\n",
    "    ScaleIntensityRangePercentilesd,\n",
    "    ToTensord,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f752ff86-c24c-423b-9127-df34d60d86e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/AR32500/AR32500/MyPapers/box-prompt-learning-VFM/src\n"
     ]
    }
   ],
   "source": [
    "# Remember the directory the kernel started in, so that re-running this cell does not\n",
    "# climb two more levels each time (a bare os.chdir on os.getcwd() is cumulative).\n",
    "if '_NOTEBOOK_START_DIR' not in globals():\n",
    "    _NOTEBOOK_START_DIR = os.getcwd()\n",
    "\n",
    "# Work from the grandparent of the directory the notebook was started in\n",
    "os.chdir(os.path.dirname(os.path.dirname(_NOTEBOOK_START_DIR)))\n",
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12c25ae0-905d-4522-a06f-2964514bacaa",
   "metadata": {},
   "source": [
    "## Functions for preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a0cd4c13-bede-431e-b52f-3583134eb896",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_val_patient_idx(train_patient_idx=None, num_val_patients=10):\n",
    "    \"\"\"Randomly pick validation patients among the training patients.\n",
    "\n",
    "    Args:\n",
    "        train_patient_idx (list, optional): candidate patient identifiers. Defaults to None,\n",
    "            which is treated as an empty list.\n",
    "        num_val_patients (int, optional): number of patients to draw. Defaults to 10.\n",
    "\n",
    "    Returns:\n",
    "        list: the drawn patient identifiers, sorted.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if num_val_patients is negative or larger than the number of candidates.\n",
    "    \"\"\"\n",
    "    # None replaces the mutable [] default; list() leaves the caller's sequence untouched\n",
    "    candidates = [] if train_patient_idx is None else list(train_patient_idx)\n",
    "    if not 0 <= num_val_patients <= len(candidates):\n",
    "        raise ValueError(\n",
    "            'Cannot draw {} validation patients from {} training patients'.format(\n",
    "                num_val_patients, len(candidates)))\n",
    "\n",
    "    print('Creating a validation dataset with {} patients'.format(num_val_patients))\n",
    "    # Drawn without replacement from the global `random` generator\n",
    "    return sorted(random.sample(candidates, num_val_patients))\n",
    "\n",
    "\n",
    "def get_train_test_patient_idx(data_dir, images_list, num_test_patients=10):\n",
    "    \"\"\"Split patients into a training set and a test set.\n",
    "\n",
    "    If a 'preprocessed' folder exists next to data_dir, the split stored in its\n",
    "    train/scan_info.json and test/scan_info.json files is reused (the keys of each\n",
    "    JSON object are taken as patient identifiers). Otherwise num_test_patients\n",
    "    patients are drawn at random for the test set and the rest is used for training.\n",
    "\n",
    "    Args:\n",
    "        data_dir (str): data directory\n",
    "        images_list (list): list of paths to each images (ending with .nii.gz)\n",
    "        num_test_patients (int, optional): number of test patients. Defaults to 10.\n",
    "\n",
    "    Returns:\n",
    "        train_patient_idx (list)\n",
    "        test_patient_idx (list)\n",
    "    \"\"\"\n",
    "    # If the data was already preprocessed, we take the patients idx separation already used\n",
    "    preprocessed_data_dir = os.path.join(os.path.dirname(data_dir), 'preprocessed')\n",
    "\n",
    "    def _read_patient_idx(split):\n",
    "        # Patient identifiers are the top-level keys of <preprocessed>/<split>/scan_info.json\n",
    "        with open(os.path.join(preprocessed_data_dir, split, 'scan_info.json')) as f:\n",
    "            return list(json.load(f).keys())\n",
    "\n",
    "    if os.path.exists(preprocessed_data_dir):\n",
    "        return _read_patient_idx('train'), _read_patient_idx('test')\n",
    "\n",
    "    print('{} does not exist'.format(preprocessed_data_dir))\n",
    "\n",
    "    # We choose test patients\n",
    "    patient_name_list = [os.path.basename(path).replace('.nii.gz', '') for path in images_list]\n",
    "    test_patient_idx = sorted(random.sample(patient_name_list, num_test_patients))\n",
    "\n",
    "    # A set lookup keeps the filtering linear in the number of patients\n",
    "    test_patient_set = set(test_patient_idx)\n",
    "    train_patient_idx = [name for name in patient_name_list if name not in test_patient_set]\n",
    "\n",
    "    return train_patient_idx, test_patient_idx\n",
    "    \n",
    "    \n",
    "def get_train_val_test_list(data_dir, images_list, num_test_patients=10, num_val_patients=10):\n",
    "    \"\"\"Build the train / validation / test patient lists.\n",
    "\n",
    "    The test patients are separated first, then validation patients are drawn from the\n",
    "    remaining training candidates; whatever is left forms the training list.\n",
    "\n",
    "    Returns:\n",
    "        train_patient_idx (list)\n",
    "        val_patient_idx (list)\n",
    "        test_patient_idx (list)\n",
    "    \"\"\"\n",
    "    train_candidates, test_patient_idx = get_train_test_patient_idx(\n",
    "        data_dir, images_list, num_test_patients\n",
    "    )\n",
    "    val_patient_idx = get_val_patient_idx(train_candidates, num_val_patients)\n",
    "\n",
    "    # Sanity check: validation patients must all come from the training candidates\n",
    "    assert all(value in train_candidates for value in val_patient_idx)\n",
    "\n",
    "    train_patient_idx = []\n",
    "    for candidate in train_candidates:\n",
    "        if candidate not in val_patient_idx:\n",
    "            train_patient_idx.append(candidate)\n",
    "\n",
    "    # Report the size of each split\n",
    "    for split_label, split_patients in (\n",
    "        ('train patients:', train_patient_idx),\n",
    "        ('val patients:', val_patient_idx),\n",
    "        ('test patients:', test_patient_idx),\n",
    "    ):\n",
    "        print(split_label, len(list(split_patients)))\n",
    "\n",
    "    return train_patient_idx, val_patient_idx, test_patient_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7e721ba8-c7af-4b27-aa1a-c2478fb9a1af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_sam_directories(base_dir, type='slice'):\n",
    "    \"\"\"Create the output folders under base_dir and return their paths.\n",
    "\n",
    "    With type='slice', one folder per (split, kind) pair is created, named\n",
    "    '<split>_<kind>' for split in train/val/test and kind in 2d_images/2d_masks.\n",
    "    With type='volume', the folders 'imagesTr' and 'labelsTr' are created.\n",
    "    Any other value creates nothing and yields an empty dict.\n",
    "\n",
    "    Returns:\n",
    "        dict: folder name -> full path\n",
    "    \"\"\"\n",
    "    if type == 'slice':\n",
    "        folder_names = [\n",
    "            f'{split}_{kind}'\n",
    "            for split in ('train', 'val', 'test')\n",
    "            for kind in ('2d_images', '2d_masks')\n",
    "        ]\n",
    "    elif type == 'volume':\n",
    "        folder_names = ['imagesTr', 'labelsTr']\n",
    "    else:\n",
    "        folder_names = []\n",
    "\n",
    "    dir_paths = {}\n",
    "    for folder_name in folder_names:\n",
    "        folder_path = os.path.join(base_dir, folder_name)\n",
    "        dir_paths[folder_name] = folder_path\n",
    "        # exist_ok: folders left by a previous run are reused\n",
    "        os.makedirs(folder_path, exist_ok=True)\n",
    "    return dir_paths\n",
    "\n",
    "\n",
    "def ceil_to_multiple_of_5(n):\n",
    "    \"\"\"Round n up to the nearest multiple of 5 (the result is a NumPy float).\"\"\"\n",
    "    number_of_fives = np.ceil(n / 5.)\n",
    "    return 5 * number_of_fives"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27db9ade-3af0-444b-97d2-7cf3f940b292",
   "metadata": {},
   "source": [
    "# Preprocessing dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4ef464a3-72a5-40e5-a19c-f4fccdf9b117",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the raw data and naming of the files to preprocess\n",
    "data_dir = '/data/users/melanie/data/'\n",
    "dataset_name = 'CAMUS_public'\n",
    "file_type = '_2CH_ED'  # appended to the patient name to build the image / mask file names\n",
    "suffix = '_niigz'  # NOTE(review): not referenced by the cells of this notebook - confirm it is still needed\n",
    "\n",
    "# Fractions of the patients assigned to the test and validation splits\n",
    "frac_test_patients = 0.2\n",
    "frac_val_patients = 0.1\n",
    "\n",
    "# Seed for the global `random` generator: the train/val/test split is drawn with\n",
    "# random.sample, so without a seed every kernel restart yields a different split.\n",
    "random_seed = 0\n",
    "random.seed(random_seed)\n",
    "\n",
    "# When True, images whose mask does not contain every label of class_list are skipped\n",
    "remove_background_slices = True\n",
    "class_list = [1, 2, 3]\n",
    "\n",
    "# For _512 data\n",
    "crop_pad_size = (512, 512)\n",
    "new_size = (512, 512) #(no resizing)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e5b5bcd8-3def-488b-bf93-ed0aad805e19",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "500\n",
      "/data/users/melanie/data/CAMUS_public/raw/preprocessed does not exist\n",
      "Creating a validation dataset with 50 patients\n",
      "train patients: 350\n",
      "val patients: 50\n",
      "test patients: 100\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'train_2d_images': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/train_2d_images',\n",
       " 'train_2d_masks': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/train_2d_masks',\n",
       " 'val_2d_images': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/val_2d_images',\n",
       " 'val_2d_masks': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/val_2d_masks',\n",
       " 'test_2d_images': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/test_2d_images',\n",
       " 'test_2d_masks': '/data/users/melanie/data/CAMUS_public/preprocessed_sam/test_2d_masks'}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Six output folders are created under 'preprocessed_sam' in the dataset directory:\n",
    "# '2d_images' and '2d_masks' for each of the splits 'train', 'val' and 'test'.\n",
    "\n",
    "dataset_root = os.path.join(data_dir, dataset_name)\n",
    "raw_data_dir = os.path.join(dataset_root, 'raw', 'database_nifti')\n",
    "base_dir_slice = os.path.join(dataset_root, 'preprocessed_sam')\n",
    "\n",
    "# Entries of the raw database folder (presumably one sub-folder per patient)\n",
    "patient_name_list = sorted(os.listdir(raw_data_dir))\n",
    "print(len(patient_name_list))\n",
    "\n",
    "# Split sizes: the test count is rounded up to a multiple of 5, the validation count is truncated\n",
    "num_patients = len(patient_name_list)\n",
    "num_test_patients = int(ceil_to_multiple_of_5(num_patients * frac_test_patients))\n",
    "num_val_patients = int(num_patients * frac_val_patients)\n",
    "\n",
    "# We get train, val and test patient names\n",
    "train_patient_idx, val_patient_idx, test_patient_idx = get_train_val_test_list(\n",
    "    raw_data_dir, patient_name_list, num_test_patients, num_val_patients\n",
    ")\n",
    "\n",
    "# Create directories to save the preprocessed slices\n",
    "dir_paths_slice = create_sam_directories(base_dir_slice, type='slice')\n",
    "\n",
    "dir_paths_slice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54522468-084a-4199-b833-606663f6ffb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "350 50 100\n"
     ]
    }
   ],
   "source": [
    "# val_patient_idx = ['patient0014', 'patient0016', 'patient0040', 'patient0057', 'patient0059', 'patient0068', 'patient0069', 'patient0076', 'patient0085', 'patient0091', 'patient0099', 'patient0107', 'patient0135', 'patient0156', 'patient0165', 'patient0187', 'patient0194', 'patient0203', 'patient0213', 'patient0225', 'patient0236', 'patient0237', 'patient0239', 'patient0251', 'patient0258', 'patient0259', 'patient0281', 'patient0293', 'patient0302', 'patient0305', 'patient0321', 'patient0324', 'patient0326', 'patient0329', 'patient0340', 'patient0344', 'patient0347', 'patient0356', 'patient0370', 'patient0379', 'patient0385', 'patient0414', 'patient0422', 'patient0431', 'patient0448', 'patient0458', 'patient0460', 'patient0475', 'patient0485', 'patient0494']\n",
    "# test_patient_idx = ['patient0011', 'patient0022', 'patient0029', 'patient0030', 'patient0034', 'patient0037', 'patient0041', 'patient0051', 'patient0055', 'patient0058', 'patient0060', 'patient0062', 'patient0066', 'patient0072', 'patient0073', 'patient0074', 'patient0089', 'patient0096', 'patient0097', 'patient0101', 'patient0113', 'patient0114', 'patient0115', 'patient0120', 'patient0124', 'patient0125', 'patient0134', 'patient0138', 'patient0146', 'patient0150', 'patient0158', 'patient0161', 'patient0162', 'patient0173', 'patient0183', 'patient0189', 'patient0190', 'patient0210', 'patient0212', 'patient0221', 'patient0227', 'patient0228', 'patient0229', 'patient0233', 'patient0244', 'patient0252', 'patient0255', 'patient0257', 'patient0261', 'patient0266', 'patient0276', 'patient0277', 'patient0292', 'patient0295', 'patient0296', 'patient0303', 'patient0318', 'patient0330', 'patient0334', 'patient0335', 'patient0336', 'patient0338', 'patient0341', 'patient0343', 'patient0349', 'patient0350', 'patient0352', 'patient0354', 'patient0358', 'patient0365', 'patient0369', 'patient0372', 'patient0373', 'patient0380', 'patient0386', 'patient0390', 'patient0393', 'patient0400', 'patient0404', 'patient0407', 'patient0413', 'patient0423', 'patient0428', 'patient0430', 'patient0438', 'patient0439', 'patient0442', 'patient0443', 'patient0444', 'patient0445', 'patient0456', 'patient0464', 'patient0473', 'patient0476', 'patient0477', 'patient0486', 'patient0490', 'patient0491', 'patient0495', 'patient0499']\n",
    "# train_patient_idx = [f for f in patient_name_list if (f not in val_patient_idx) and (f not in test_patient_idx)]\n",
    "# print(len(train_patient_idx), len(val_patient_idx), len(test_patient_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "634a706f-5d57-4fb4-add2-cf3619914d2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added ScaleIntensityRangePercentilesd on Feb 14th 2024\n",
    "image_key = 'img'\n",
    "pair_keys = [image_key, 'label']\n",
    "\n",
    "preprocessing_steps = [\n",
    "    # Load the .nii / .nii.gz files and put the channel axis first\n",
    "    LoadImaged(keys=pair_keys),\n",
    "    EnsureChannelFirstd(keys=pair_keys),\n",
    "    # Center-crop to at most crop_pad_size, then pad anything smaller up to crop_pad_size\n",
    "    CenterSpatialCropd(keys=pair_keys, roi_size=crop_pad_size),\n",
    "    SpatialPadd(keys=pair_keys, spatial_size=crop_pad_size),\n",
    "    # Bilinear interpolation for the image, nearest-neighbour for the label map\n",
    "    Resized(keys=pair_keys, spatial_size=new_size, mode=['bilinear', 'nearest']),\n",
    "    # Image only: map the 0.5th-99.5th percentile intensity range onto [0, 255] and clip the rest\n",
    "    ScaleIntensityRangePercentilesd(\n",
    "        keys=[image_key], lower=0.5, upper=99.5, b_min=0, b_max=255, clip=True\n",
    "    ),\n",
    "    ToTensord(keys=pair_keys),\n",
    "]\n",
    "transforms = Compose(preprocessing_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe5bf051-d4c0-4d3d-86d3-e6b48b246f5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sets give constant-time membership tests inside the loop\n",
    "train_patient_set = set(train_patient_idx)\n",
    "val_patient_set = set(val_patient_idx)\n",
    "\n",
    "for patient_name in patient_name_list:\n",
    "    patient_dir = os.path.join(data_dir, dataset_name, 'raw', 'database_nifti', patient_name)\n",
    "    img_path = os.path.join(patient_dir, patient_name + file_type + '.nii.gz')\n",
    "    mask_path = os.path.join(patient_dir, patient_name + file_type + '_gt.nii.gz')\n",
    "\n",
    "    data_dict = transforms({'img': img_path, 'label': mask_path})\n",
    "    # Take the first (channel) index and store as 8-bit\n",
    "    img = data_dict['img'][0, :, :].astype(np.uint8)\n",
    "    mask = data_dict['label'][0, :, :].astype(np.uint8)\n",
    "\n",
    "    # Shape of the original image, read from the header without loading the pixel data\n",
    "    print(patient_name, nib.load(img_path).shape)\n",
    "\n",
    "    # Optionally skip images whose mask does not contain every label of class_list\n",
    "    unique_labels = np.unique(mask)\n",
    "    if remove_background_slices and not all(label in unique_labels for label in class_list):\n",
    "        print('not all label classes: {}'.format(patient_name))\n",
    "        continue\n",
    "\n",
    "    # Select appropriate directories (anything that is neither train nor val goes to test)\n",
    "    if patient_name in train_patient_set:\n",
    "        split = 'train'\n",
    "    elif patient_name in val_patient_set:\n",
    "        split = 'val'\n",
    "    else:\n",
    "        split = 'test'\n",
    "    img_dir = dir_paths_slice[f'{split}_2d_images']\n",
    "    mask_dir = dir_paths_slice[f'{split}_2d_masks']\n",
    "\n",
    "    # Image and mask are written under the same file name, in their respective folders\n",
    "    filename = os.path.basename(img_path)\n",
    "    img_slice_path = os.path.join(img_dir, filename)\n",
    "    mask_slice_path = os.path.join(mask_dir, filename)\n",
    "\n",
    "    sitk.WriteImage(sitk.GetImageFromArray(img), img_slice_path)\n",
    "    sitk.WriteImage(sitk.GetImageFromArray(mask), mask_slice_path)\n",
    "\n",
    "    # Visual check of the exported pair\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    ax = fig.add_subplot(121)\n",
    "    ax.imshow(img)\n",
    "    ax = fig.add_subplot(122)\n",
    "    ax.imshow(mask)\n",
    "    plt.show()\n",
    "    # Release the figure: one is created per patient and they would otherwise all stay open\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fc808e0-20b5-4390-a299-87385942c00f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
